feat: implement column parallel for lm head to improve performance. #1145

Merged
XuZhang99 merged 3 commits into jd-opensource:main from wxh571001500:main on Apr 3, 2026

Conversation

@wxh571001500
Contributor

LM head performance improved by 3%.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors the LmHead to utilize ColumnParallelLinearImpl and introduces support for vocabulary padding to ensure alignment during tensor parallel operations. The review feedback identifies several critical issues: a regression caused by hardcoding quantization arguments in a general-purpose linear layer constructor, memory inefficiencies when sharding padded tensors, and a potential bug in state dict lookups using incorrect keys. Additionally, there is a recommendation to deduplicate the vocabulary padding calculation logic into a shared utility to improve maintainability.
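The vocabulary padding and column-parallel split described in the review can be sketched roughly as follows. This is a minimal pure-Python illustration, not the PR's code: the helper names `pad_vocab_size`, `shard_range`, and `lm_head_column_parallel`, as well as the default alignment of 64, are assumptions made for the example. The loop over ranks stands in for per-device work followed by an all-gather.

```python
# Illustrative sketch only. Helper names and the alignment value are
# hypothetical and do not come from the pull request.

def pad_vocab_size(vocab_size, tp_size, align=64):
    """Round vocab_size up to a multiple of align * tp_size so that
    every tensor-parallel rank receives an equally sized, aligned shard."""
    multiple = align * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple


def shard_range(padded_vocab, tp_size, rank):
    """Half-open [start, end) range of vocabulary rows owned by one rank."""
    per_rank = padded_vocab // tp_size
    return rank * per_rank, (rank + 1) * per_rank


def lm_head_column_parallel(hidden, weight, vocab_size, tp_size, align=64):
    """Compute logits shard by shard, then concatenate and drop the padding.

    hidden: list of floats of length H.
    weight: vocab_size rows, each a list of H floats.
    """
    padded = pad_vocab_size(vocab_size, tp_size, align)
    hidden_size = len(hidden)
    logits = []
    for rank in range(tp_size):  # each iteration mimics one device
        start, end = shard_range(padded, tp_size, rank)
        for row in range(start, end):
            # Rows past the real vocabulary are padding and act as zeros.
            w = weight[row] if row < vocab_size else [0.0] * hidden_size
            logits.append(sum(h * x for h, x in zip(hidden, w)))
    # Concatenation plays the role of the all-gather; trim the padded tail.
    return logits[:vocab_size]
```

Keeping the padding calculation in one shared helper, as the review recommends, means the model definition and the weight loader cannot disagree about shard boundaries.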

RobbieLeung previously approved these changes Mar 31, 2026
Collaborator

@RobbieLeung RobbieLeung left a comment

LGTM

@XuZhang99 XuZhang99 merged commit 87d9e35 into jd-opensource:main Apr 3, 2026
22 of 39 checks passed
3 participants